/*
    Copyright 2006-2011 Patrik Jonsson, sunrise@familjenjonsson.org

    This file is part of Sunrise.

    Sunrise is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.

    Sunrise is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.

*/

/// \file
/// Contains classes related to emission of rays.

#ifndef __emission_fits__
#define __emission_fits__

#include "emission.h"
#include "aux_grid.h"

#include "blitz-fits.h"
#include "CCfits/CCfits"
#include "mcrx-debug.h"
#include "boost/thread/mutex.hpp"
#include "boost/thread/thread.hpp"
#include "boost/lambda/lambda.hpp"
#include "functors.h"

// Forward declarations of the two emission class templates defined
// further down in this header. In both cases the first template
// parameter is the sampling policy (itself a class template taking one
// type argument) and the second is the random number generator policy
// type.
namespace mcrx {
  template <template<typename> class, typename> 
  class full_sed_particles_emission;
  template <template<typename> class, typename> 
  class aux_particles_emission;
}


/** Emission class for polychromatic runs that loads a set of
    particles from FITS table data. These particles are treated as
    entries in an emission_collection.

    Copies of an object are shallow: the emitters are held through
    shared_ptrs and the luminosity array is copied by reference, so a
    copy shares the particle data with the original while having its
    own distribution object. This lets each thread work on a local
    copy without locking the weights. */
template <template<typename> class sampling_policy, 
	  typename rng_policy_type>
class mcrx::full_sed_particles_emission : 
  public mcrx::emission_collection<polychromatic_policy, 
				   sampling_policy, rng_policy_type> {
public:
  typedef typename emission_collection<polychromatic_policy,
				       sampling_policy, 
				       rng_policy_type>::T_lambda T_lambda;
  typedef typename emission_collection<polychromatic_policy, sampling_policy, rng_policy_type>::T_biaser T_biaser;
  typedef rng_policy_type T_rng_policy;
  typedef particle_emission<polychromatic_policy, rng_policy_type> T_emitter;
  typedef boost::shared_ptr<emission<polychromatic_policy,
				     rng_policy_type> > T_emitter_ptr;

private:
  // Copying from a non-const object and assignment are declared
  // private and are not defined in this header, which disables them.
  full_sed_particles_emission (full_sed_particles_emission&);
  full_sed_particles_emission& operator=(full_sed_particles_emission&);
  
  /** Vector containing the actual particle emitters. They are kept
      with shared_ptrs so they don't need to be destroyed. */
  std::vector<boost::shared_ptr<emission<polychromatic_policy, 
					 rng_policy_type> > > emitters_;
  /** Array containing the emitter luminosities. Indices are
      (particle, lambda). */
  array_2 L_lambda_;

  /// Units of the quantities read from the FITS table.
  T_unit_map units_;

public:
  /** Constructor loads emission data from the specified HDU. The
      emission weight of each particle is its bolometric luminosity.
      If grid is non-null, only the particles whose positions the
      grid reports as owned are kept. If a wavelength vector is
      supplied in lambda, the emission SEDs are subsampled from that
      vector onto subsample_to when loading (subsample_to must then
      be non-empty). */
  template<typename T_grid>
  full_sed_particles_emission (CCfits::ExtHDU&, T_grid* grid,
			       const array_1& lambda=array_1(),
			       const array_1& subsample_to=array_1());

  /** Copy constructor makes a reference copy of the L_lambda array
      and copies the emitters_ vector by value. Because the vector
      holds shared_ptrs the result is effectively a shallow copy that
      points to the same emitters. The unit map is copied as well so
      that units() returns the same information for copies and clones
      as for the original. */
  full_sed_particles_emission (const full_sed_particles_emission& rhs) : 
    emission_collection<polychromatic_policy, sampling_policy, 
			rng_policy_type>(rhs), 
    emitters_(rhs.emitters_), L_lambda_(rhs.L_lambda_),
    units_(rhs.units_) {}

  /// Returns the units of the loaded quantities.
  const T_unit_map& units() const {return units_; }

  /** Draws an emitter from the distribution and lets it emit a ray.
      The spectra of the emitters were normalized once and for all in
      the constructor, so the normalization returned by the selected
      emitter is multiplied by the total weight of the distribution.
      Since the probability of drawing emitter i is its weight
      divided by the total weight, multiplying intensity and
      probability gives the un-normalized luminosity of the emitter,
      which is what we want.

      \param b biaser object, forwarded to the selected emitter.
      \param rng random number generator policy object.
      \param prob set to the probability of selecting the emitter
      multiplied by the probability returned by the emitter itself.
      \return the ray/normalization pair produced by the emitter,
      with the rescaled normalization. */
  virtual ray_distribution_pair<T_lambda> emit (T_biaser b,
						T_rng_policy& rng, 
						T_float& prob) const {
    
    T_float temp_prob1, temp_prob2;
    ray_distribution_pair<T_lambda> temp =
      this->distr_.sample(rng, temp_prob1)->emit(b, rng, temp_prob2);
    prob = temp_prob1*temp_prob2;

    DEBUG(2,cout << "Emission max L: " << max(temp.normalization()) << endl;);

    // the normalization of the emitters was done once and for all in
    // the constructor
    temp.normalization() *= this->distr_.total_weight();

    assert(max(temp.normalization())<1e60);
    // NOTE: this debug output indexes fixed entries (532-615) of the
    // spectrum and therefore assumes the wavelength vector is at
    // least that long.
    DEBUG(3,cout << "Emitting Ratio " << temp.normalization()(535)/temp.normalization()(532) << ' '<< temp.normalization()(615)/temp.normalization()(614) << endl;);

    return temp;
  }

  // emission_weight is inherited

  /** Returns a heap-allocated (shallow) copy of this object, made
      with the copy constructor. */
  boost::shared_ptr<emission<polychromatic_policy, rng_policy_type> > clone() const {
    return T_emitter_ptr(new full_sed_particles_emission(*this));
  }

};


/** Emitter collection whose particles carry auxiliary quantities
    (aux_pars_type) rather than a spectrum. The particle data are
    read from a FITS table. */
template <template<typename> class sampling_policy, 
	  typename rng_policy_type>
class mcrx::aux_particles_emission : 
  public emission_collection<generic_chromatic_policy<aux_pars_type>,
			     sampling_policy, rng_policy_type> 
{
public:
  /// Random number generator policy type.
  typedef rng_policy_type T_rng_policy;
  /// Type of the quantity carried by the emitters.
  typedef typename emission_collection<generic_chromatic_policy<aux_pars_type>,
				       sampling_policy, 
				       rng_policy_type>::T_lambda T_content;
  /// Concrete emitter type created for each particle.
  typedef particle_emission<generic_chromatic_policy<aux_pars_type>,
			    rng_policy_type> T_emitter;
  /// Shared pointer to the emission base class.
  typedef boost::shared_ptr<emission<generic_chromatic_policy<aux_pars_type>,
				     rng_policy_type> > T_emitter_ptr;

  /// Loads the aux particle data from the specified HDU.
  aux_particles_emission (CCfits::ExtHDU&);

  /// Copy constructor, copies the emission_collection base.
  aux_particles_emission (const aux_particles_emission& rhs) :
    emission_collection<generic_chromatic_policy<aux_pars_type>,
			sampling_policy, rng_policy_type> (rhs) {}

  /// Does the loading work on behalf of the HDU constructor.
  void fucked_up(CCfits::ExtHDU& hdu);

  // emit and emission_weight are inherited

  /// Returns a heap-allocated copy of this object, made with the
  /// copy constructor.
  T_emitter_ptr clone() const {
    T_emitter_ptr copy(new aux_particles_emission(*this));
    return copy;
  }

private:
  // Copying from a non-const object and assignment are declared
  // private and are not defined in this header.
  aux_particles_emission (aux_particles_emission&);
  aux_particles_emission& operator=(aux_particles_emission&);
};


/** Loads the particle emitters from the FITS table in hdu.

    The columns "position", "velocity", "radius", "L_lambda" and
    "L_bol" are read, and their units are recorded in the unit map.
    If the keyword "logflux" is present and true, L_lambda is taken
    to be logarithmic and is exponentiated. The SEDs are divided by
    L_bol, and L_bol is used as the emission weight of each particle.

    \param hdu the HDU containing the particle table.
    \param grid if non-null, grid->position_ownership() is asked
    which particle positions are owned, and all other particles are
    discarded.
    \param lambda if non-empty, the wavelength vector of the SEDs in
    the file; its length must equal the number of SED columns. In
    that case the SEDs are subsampled onto subsample_to.
    \param subsample_to the wavelength vector to subsample to. Must
    be non-empty if lambda is. */
template <template<typename> class sampling_policy, 
	  typename rng_policy_type>
template <typename T_grid>
mcrx::full_sed_particles_emission<sampling_policy, rng_policy_type>::
full_sed_particles_emission(CCfits::ExtHDU& hdu,
			    T_grid* grid,
			    const array_1& lambda,
			    const array_1& subsample_to) : 
  emission_collection<polychromatic_policy, sampling_policy, rng_policy_type>()
{
  using namespace CCfits;
  using namespace std;
  using namespace blitz;

  std::cout << "Using L_bol as emission weight." << std::endl;

  Column& c_position = hdu.column("position" );
  Column& c_velocity = hdu.column("velocity" );
  Column& c_s_radius = hdu.column("radius" );
  Column& c_l_lambda = hdu.column("L_lambda");
  Column& c_l_bol = hdu.column("L_bol" );

  units_["length"] = c_position.unit();
  units_["luminosity"] = c_l_bol.unit();
  units_["L_lambda"] = c_l_lambda.unit();
  units_["velocity"] = c_velocity.unit();

  vector<vec3d> position, velocity;
  vector<T_float> radius;
  array_1 L_bol;

  read (c_position, position);
  read (c_velocity, velocity);
  c_s_radius.read(radius, 1, c_s_radius.rows() );
  read (c_l_lambda, L_lambda_);
  cout << "Allocated a (" << L_lambda_.extent(firstDim) << ", " << L_lambda_.extent(secondDim)
       << ") memory block for particle emission, "
       << L_lambda_.size()*sizeof(T_float)*1.0/(1024*1024*1024) << " GB.\n";
  read (c_l_bol, L_bol);

  // if a grid was specified, ask it which positions we own
  if(grid) {
    const vector<bool> own_loc = grid->position_ownership(position);
    assert(own_loc.size()==position.size());
    assert(velocity.size()==position.size());
    assert(radius.size()==position.size());

    const int ntotal = static_cast<int>(own_loc.size());
    const int nown = 
      static_cast<int>(std::count(own_loc.begin(), own_loc.end(), true));
    cout << "Owning " << nown << '/' << ntotal << " emitters\n";

    // Cut out the non-owned entries by moving the owned ones to the
    // front of every per-particle container in one pass. This is
    // done with an explicit loop instead of std::remove_if with a
    // stateful predicate, because the standard algorithms are free
    // to copy their predicate, which makes a predicate that tracks
    // its position in own_loc unreliable.
    int dst = 0;
    for(int src=0; src<ntotal; ++src) {
      if(!own_loc[src])
	continue;
      if(dst!=src) {
	position[dst] = position[src];
	velocity[dst] = velocity[src];
	radius[dst] = radius[src];
	L_bol(dst) = L_bol(src);
	L_lambda_(dst,Range::all()) = L_lambda_(src,Range::all());
      }
      ++dst;
    }
    assert(dst==nown);

    // and chop off the now unused tails
    position.erase(position.begin()+nown, position.end());
    velocity.erase(velocity.begin()+nown, velocity.end());
    radius.erase(radius.begin()+nown, radius.end());
    L_bol.resizeAndPreserve(nown);
    L_lambda_.resizeAndPreserve(nown,L_lambda_.extent(secondDim));
  
    assert(position.size()==static_cast<size_t>(nown));
    assert(velocity.size()==static_cast<size_t>(nown));
    assert(radius.size()==static_cast<size_t>(nown));
  }
  
  // look for logarithmic flux
  try {
    bool logarithmic= false;
    hdu.readKey("logflux", logarithmic);
    if (logarithmic)
      L_lambda_ = pow (10., L_lambda_);
  }
  catch (HDU::NoSuchKeyword&) {}

  ASSERT_ALL(L_lambda_>=0);

  // need velocity in m/s because that's what we have c0 in.
  const T_float velcon = 
    units::convert(units_.get("velocity"), "m/s");

  // see if we must subsample. Note that it is the presence of the
  // source wavelength vector lambda that triggers subsampling, and
  // that subsample_to is then required to be non-empty.
  if(lambda.size()>0) {
    cout << "Subsampling emission SEDs" << endl;
    assert(subsample_to.size()>0);
    const size_t nl=lambda.size();
    const size_t nc=L_lambda_.extent(firstDim);
    assert(nl==L_lambda_.extent(secondDim));

    array_2 new_L_lambda(nc, subsample_to.size());
    for(size_t c=0; c<nc; ++c)
      new_L_lambda(c,Range::all())=subsample(L_lambda_(c,Range::all()),
					     lambda, subsample_to);
    // and swap out the old one
    L_lambda_.reference(new_L_lambda);
  }

  // This vector is used to collect the pointers to the emitters for
  // the call to distr_.setup.
  std::vector<emission<polychromatic_policy, rng_policy_type>*> pointers;

  // we are using bolometric weights, so divide by bolometric now
  L_lambda_ = L_lambda_(tensor::i, tensor::j)/L_bol(tensor::i);

  // generate particles, one emitter per entry in the radius vector
  const int n_emitters = static_cast<int>(radius.size());
  emitters_.reserve(n_emitters);
  pointers.reserve(n_emitters);
  for (int i = 0; i < n_emitters; ++i) {
      // The normalization of the emitter spectrum is L_lambda/L_bol.

      // Using weakReference is necessary to avoid getting swamped by
      // the array reference count mutex.
      array_1 temp_lum;
      temp_lum.weakReference(L_lambda_ (i, Range::all ()));
      assert (temp_lum.isStorageContiguous());
      emitters_.push_back(T_emitter_ptr(new T_emitter(position[i], 
						      velocity[i]*velcon,
						      radius[i],
						      temp_lum)));
      DEBUG(3,cout << "Emitter velocity " << vec3d(velocity[i]*velcon) << " in full_sed_particles_emission" << endl;);
      // The pointers are shared_ptrs, so that ensures we don't have
      // problems making copies of *this, they will all point to the
      // same emitters and will be destroyed appropriately.
      pointers.push_back(emitters_.back().get());
    }

  // Now set up probability distribution object, setting L_bol as the
  // emission weight.
  this->distr_.setup(pointers, L_bol);

  DEBUG(2,cout << "Max L in all emitters: " << max(L_lambda_) << endl;);
}


/** Does the actual loading work for the HDU constructor: reads the
    particle columns from the table, creates one emitter per row and
    sets up the sampling distribution with the particle masses as
    weights. It is a separate member function, called from the
    constructor, since gdb can't debug constructors. */
template <template<typename> class sampling_policy, 
	  typename rng_policy_type>
void 
mcrx::aux_particles_emission<sampling_policy, rng_policy_type>::
fucked_up(CCfits::ExtHDU& hdu)
{
  using namespace CCfits;
  using namespace std;
  using namespace blitz;

  // Look up all the columns up front, before any data is read.
  Column& col_pos = hdu.column("position" );
  Column& col_vel = hdu.column("velocity" );
  Column& col_radius = hdu.column("radius" );
  Column& col_lbol = hdu.column("L_bol" );
  Column& col_mass = hdu.column("mass" );
  Column& col_metallicity = hdu.column("metallicity" );
  Column& col_age = hdu.column("age");

  // The radius column determines the number of particles.
  const int n= col_radius.rows();

  // Positions and velocities are read as 2D arrays, indexed
  // (particle, component).
  array_2 pos_data;
  array_2 vel_data;
  read (col_pos, pos_data);
  read (col_vel, vel_data);

  // The scalar quantities are read as plain vectors.
  vector<T_float> radius;
  vector<T_float> lum_bol;
  vector<T_float> stellar_mass;
  vector<T_float> metallicity;
  vector<T_float> age;
  col_radius.read(radius, 1, n );
  col_lbol.read(lum_bol, 1, n );
  col_mass.read(stellar_mass, 1, n);
  col_metallicity.read(metallicity, 1, n);
  col_age.read(age, 1, n);
  
  // Emitter pointers and weights are collected for the call to
  // distr_.setup.
  std::vector<emission<generic_chromatic_policy<T_content>,
    rng_policy_type>*> pointers;
  array_1 weights(n);

  // generate particles
  for (int p = 0; p < n; ++p) {
    // The mass is the emission weight in all cases.
    const T_float emission_weight = stellar_mass [p];

    T_content d(0.);
    // The bolometric luminosity is always included.
    d[aux_pars_fields::L_bol] = lum_bol [p];

    // A NaN age (NaN compares unequal to itself) means this is
    // (better be) a black hole.
    const bool age_is_nan = (age[p]!=age[p]);
    if (age_is_nan) {
      // Set all mass fields to zero but include it in the
      // luminosity. (This very slightly screws up the
      // luminosity-weighted age because the BH L_bol will be
      // included in counting it.)
      d[aux_pars_fields::mass_stars] = 0;
      d[aux_pars_fields::mass_metals_stars] = 0;
      d[aux_pars_fields::age_m] = 0;
      d[aux_pars_fields::age_l] = 0;
    }
    else {
      d[aux_pars_fields::mass_stars] = stellar_mass [p];
      d[aux_pars_fields::mass_metals_stars] = stellar_mass [p]*metallicity[p];
      d[aux_pars_fields::age_m] = age [p]*stellar_mass [p];
      d[aux_pars_fields::age_l] = age [p]*lum_bol [p];
    }
    // The emitted content is normalized by the emission weight.
    d*= 1/emission_weight;

    const vec3d pos(pos_data(p,0),pos_data(p,1),pos_data(p,2));
    const vec3d vel(vel_data(p,0),vel_data(p,1),vel_data(p,2));

    // save variables for setup of emission_collection 
    T_emitter* const emitter = new T_emitter (pos, vel, radius[p], d);
    pointers.push_back(emitter);
    weights(p) = emission_weight;
  }

  // init probability distribution
  this->distr_.setup(pointers, weights);
}


/** Constructor loads aux particle data from the specified HDU. The
    emission_collection base is default-constructed and all of the
    loading work is delegated to the fucked_up() member function. */
template <template<typename> class sampling_policy, 
	  typename rng_policy_type>
mcrx::aux_particles_emission<sampling_policy, rng_policy_type>::
aux_particles_emission(CCfits::ExtHDU& hdu) :
  emission_collection<generic_chromatic_policy<aux_pars_type>,
		      sampling_policy, rng_policy_type> ()
{
  // reads the table and sets up the emitters and the distribution
  fucked_up(hdu);
}


#endif
